
import os

import graphlearning as gl
import time
import random
import numpy as np
import method

import argparse
import pprint
import logging
from tqdm import tqdm

from copy import deepcopy
import torch.multiprocessing as mp
import torch
from sklearn.model_selection import ParameterGrid

def get_logger(args):
    """Configure the root logger and return it.

    ``logging.basicConfig`` is pointed at ``<args.save>/log.txt`` with a
    timestamped format, the root level is set to INFO, and an extra
    ``StreamHandler`` is attached to the root logger.

    Args:
        args: object exposing ``save``, the directory that holds ``log.txt``.

    Returns:
        The root ``logging.Logger``.
    """
    log_path = os.path.join(args.save, 'log.txt')
    logging.basicConfig(filename=log_path, format='%(asctime)-15s %(message)s')

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    # Second handler so messages are also emitted on the console stream.
    root.addHandler(logging.StreamHandler())
    return root


def gaussian_eta(a=0.25, b=1.0):
    """Build the function ``x -> exp(-x / a) + b``.

    Args:
        a: divisor applied to ``x`` inside the exponential.
        b: constant added to the exponential.

    Returns:
        A one-argument callable applying the formula element-wise via NumPy.
    """
    def eta(x):
        return np.exp(-x / a) + b

    return eta

def linear_eta(a=1, b=1):
    """Build the function ``x -> a * (max(x) - x) + b``.

    The maximum is taken along the last axis of ``x`` with ``keepdims=True``,
    so it broadcasts back against ``x``.

    Args:
        a: scale applied to the gap between the last-axis maximum and ``x``.
        b: constant added to the scaled gap.

    Returns:
        A one-argument callable applying the formula via NumPy.
    """
    def eta(x):
        last_axis_max = np.max(x, axis=-1, keepdims=True)
        return a * (last_axis_max - x) + b

    return eta

def minus(f, t):
    """Wrap ``f`` so that each row's minimum output is shifted to ``t``.

    The returned callable computes ``f(x) - f(x).min(1, keepdims=True) + t``.
    ``f`` is evaluated only once per call (the original evaluated it twice).

    Args:
        f: callable whose result supports ``.min(1, keepdims=True)``, i.e. an
            array with at least two dimensions.
        t: constant added after subtracting the per-row minimum.

    Returns:
        A one-argument callable applying the shifted version of ``f``.
    """
    def shifted(x):
        values = f(x)
        return values - values.min(1, keepdims=True) + t

    return shifted
    

# Command-line interface. Flags, destinations, types and defaults are part of
# the script's contract; only the description typo and help texts were added.
parser = argparse.ArgumentParser(description='PyTorch Graph-based Semi-Supervised Learning')
# dataset
parser.add_argument('-d', '--dataset', metavar='DATASET', default='cifar', type=str,
                    help='dataset name used to load labels and build the kNN weight matrix')
parser.add_argument('--neighbors', default=10, type=int,
                    help='number of neighbors for the kNN weight matrix')
parser.add_argument('-k', '--kernel', metavar='KERNEL', default='gaussian', type=str,
                    help='kernel name forwarded to the kNN weight matrix builder')
parser.add_argument('-e', '--eta', metavar='ETA', default='gaussian', type=str,
                    help='eta factory to use; "<ETA>_eta" must be a function in this script '
                         '(gaussian, linear)')
parser.add_argument('--a', default=0.25, type=float,
                    help='first parameter passed to the eta factory')
parser.add_argument('--b', default=1.0, type=float,
                    help='second parameter passed to the eta factory')
parser.add_argument('--minus', default=False, action='store_true',
                    help='wrap eta with minus(), subtracting its per-row minimum and adding --t')
parser.add_argument('--t', default=0.0, type=float,
                    help='offset used by minus(); only read when --minus is given')

parser.add_argument('--num-trials', default=100, type=int,
                    help='number of seeded trials per (method, number of labels) setting')
parser.add_argument('--save', default='./results/baselines/', type=str,
                    help='base output directory; a run-specific subdirectory is created inside it')
# NOTE(review): args.gpu is parsed but not read anywhere in the visible script.
parser.add_argument('--gpu', default=0, type=int,
                    help='GPU index (parsed, but not read by the visible code)')


def _model_builders(args):
    """Map each supported ``args.method`` value to a zero-argument model factory."""
    return {
        'laplace': lambda: gl.ssl.laplace(args.W),
        'nearest_neighbor': lambda: gl.ssl.graph_nearest_neighbor(args.W),
        'randomwalk': lambda: gl.ssl.randomwalk(args.W),
        'wnll': lambda: gl.ssl.laplace(args.W, reweighting='wnll'),
        # The key is spelled 'centred_kernel' to match the parameter grid.
        'centred_kernel': lambda: gl.ssl.centered_kernel(args.W),
        'plaplace': lambda: gl.ssl.plaplace(args.W),
        # The original built this model without the graph, unlike every other
        # branch; pass args.W so it is consistent with them.
        'sparse_lp': lambda: gl.ssl.sparse_label_propagation(args.W),
    }


def main_worker(args, gpu):
    """Run ``args.num_trials`` seeded trials of one SSL method.

    Args:
        args: namespace providing ``method`` (model name), ``W`` (weight
            matrix handed to the model), ``labels`` (ground-truth labels),
            ``num_labels`` (forwarded as ``rate`` to
            ``gl.trainsets.generate``) and ``num_trials``.
        gpu: CUDA device index for this worker. It is wrapped modulo the
            number of visible devices and ignored when CUDA is unavailable.

    Returns:
        1-D ``numpy`` array with one accuracy per trial that completed.
        Trials that raised are logged and left out, so the array may be
        shorter than ``args.num_trials``.

    Raises:
        ValueError: if ``args.method`` is not a supported method name.
    """
    builders = _model_builders(args)
    if args.method not in builders:
        # Previously an unknown name left `model` unbound and the resulting
        # NameError was swallowed inside the trial loop.
        raise ValueError(
            'unknown method %r; expected one of %s' % (args.method, sorted(builders))
        )

    # Only touch CUDA when it exists, and keep the index inside the valid
    # range so callers that assume a fixed GPU count do not crash the worker.
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu % torch.cuda.device_count())

    model = builders[args.method]()
    log = logging.getLogger(__name__)

    accs = []
    for trial in range(args.num_trials):
        # Seed both RNGs with the trial index so every method sees the same
        # sequence of random draws for a given trial.
        random.seed(trial)
        np.random.seed(trial)
        train_ind = gl.trainsets.generate(args.labels, rate=args.num_labels)
        train_labels = args.labels[train_ind]
        try:
            pred_labels = model.fit_predict(train_ind, train_labels, all_labels=None)
            accuracy = gl.ssl.ssl_accuracy(args.labels, pred_labels, len(train_ind))
        except Exception:
            # Best effort: a failing trial is skipped, but no longer silently.
            log.warning(
                'method=%s num_labels=%s: trial %d failed and is skipped',
                args.method, args.num_labels, trial, exc_info=True,
            )
            continue
        accs.append(accuracy)
    return np.array(accs)


# Script entry point: build the kNN graph once, evaluate every
# (method, number of labels) pair in a process pool, then log and save results.
if __name__ == "__main__":
    args = parser.parse_args()
    # Each hyper-parameter combination gets its own output subdirectory.
    run_name = '_'.join(
        str(v) for v in [args.dataset, args.neighbors, args.eta, args.a, args.b, args.minus, args.t]
    )
    args.save = os.path.join(args.save, run_name)
    os.makedirs(args.save, exist_ok=True)
    logger = get_logger(args)
    logger.info('\n' + pprint.pformat(args))

    # data loading
    args.labels = gl.datasets.load(args.dataset, labels_only=True)
    # Look the "<name>_eta" factory up by name instead of eval()-ing text that
    # comes straight from the command line.
    eta_factory = globals().get(args.eta + '_eta')
    if not callable(eta_factory):
        parser.error('unknown --eta %r: no function named %s_eta' % (args.eta, args.eta))
    eta = eta_factory(args.a, args.b)
    eta = minus(eta, args.t) if args.minus else eta
    metric = 'aet' if args.dataset == 'cifar' else 'vae'
    args.W = gl.weightmatrix.knn(args.dataset, args.neighbors, kernel=args.kernel, eta=eta, metric=metric)

    params = {
        'a_method': ['laplace', 'nearest_neighbor', 'randomwalk', 'wnll', 'centred_kernel', 'plaplace'],
        'b_num_labels': [1, 2, 3, 4, 5]
    }
    grid = ParameterGrid(params)
    ctx = torch.multiprocessing.get_context("spawn")
    NUM_PROCESSING = 30

    pool_list = []
    with ctx.Pool(NUM_PROCESSING) as pool:
        for i, param in enumerate(grid):
            targs = deepcopy(args)
            targs.method = param['a_method']
            targs.num_labels = param['b_num_labels']
            # Tasks are spread round-robin over GPU indices 0..4.
            pool_list.append(pool.apply_async(main_worker, args=(targs, i % 5)))

        # Wait for every task before leaving the block (exit terminates the pool).
        pool.close()
        pool.join()

    # The grid is iterated a second time; this relies on it yielding the
    # combinations in the same order as during submission.
    rows = []
    for param, res in zip(grid, pool_list):
        accs = res.get().reshape(-1)
        if accs.size:
            logger.info(str(param) + ': %.1f (%.1f)' % (accs.mean(), accs.std()))
        else:
            logger.warning(str(param) + ': no successful trials')
        # A worker omits trials that failed, so pad with NaN to keep every row
        # at num_trials columns; ragged rows used to crash the stacking step.
        row = np.full(args.num_trials, np.nan)
        row[:accs.size] = accs
        rows.append(row)

    results = np.vstack(rows)
    np.savetxt(os.path.join(args.save, 'results.txt'), results, fmt='%.2f', delimiter=',')
    
